{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Anaconda3\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import glob as gb\n",
    "import os\n",
    "\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "import data_utils\n",
    "from losses import adversarial_loss, generator_loss\n",
    "from model import generator_model, discriminator_model, generator_containing_discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _save_weights(g, d):\n",
    "    '''Write the generator and discriminator weights under weight/, creating the directory if needed.'''\n",
    "    os.makedirs('weight', exist_ok=True)\n",
    "    g.save_weights('weight/generator_weights.h5', True)\n",
    "    d.save_weights('weight/discriminator_weights.h5', True)\n",
    "\n",
    "\n",
    "def train(batch_size, epoch_num):\n",
    "    '''Alternately train the discriminator and the generator on the 'train' split.\n",
    "\n",
    "    Args:\n",
    "        batch_size: number of (blur, full) image pairs per training step.\n",
    "        epoch_num: number of passes over the training data.\n",
    "\n",
    "    Side effects:\n",
    "        Prints the losses of every batch, passes interim samples to\n",
    "        data_utils.generate_image every 30 batches, and saves both weight\n",
    "        files under weight/ every 30 batches and at the end of each epoch.\n",
    "    '''\n",
    "    # Note the x(blur) in the second, the y(full) in the first\n",
    "    y_train, x_train = data_utils.load_data(data_type='train')\n",
    "\n",
    "    # GAN\n",
    "    g = generator_model()\n",
    "    d = discriminator_model()\n",
    "    d_on_g = generator_containing_discriminator(g, d)\n",
    "\n",
    "    # compile the models, use default optimizer parameters\n",
    "    # generator uses generator_loss\n",
    "    g.compile(optimizer='adam', loss=generator_loss)\n",
    "    # Keras fixes a model's set of trainable weights when compile() is called,\n",
    "    # so the flag has to be set BEFORE compiling, not only inside the loop.\n",
    "    # discriminator uses binary cross entropy loss and is compiled trainable\n",
    "    d.trainable = True\n",
    "    d.compile(optimizer='adam', loss='binary_crossentropy')\n",
    "    # adversarial net uses adversarial loss and is compiled with d frozen,\n",
    "    # so that training it does not update the discriminator\n",
    "    d.trainable = False\n",
    "    d_on_g.compile(optimizer='adam', loss=adversarial_loss)\n",
    "    d.trainable = True\n",
    "\n",
    "    # loop-invariant values: number of full batches and the label vectors\n",
    "    batch_count = x_train.shape[0] // batch_size\n",
    "    # 1 for the full images (top half), 0 for the generated images (bottom half)\n",
    "    d_labels = np.array([1] * batch_size + [0] * batch_size)\n",
    "    # the adversarial net is trained towards the label 1\n",
    "    g_labels = np.array([1] * batch_size)\n",
    "\n",
    "    for epoch in range(epoch_num):\n",
    "        print('epoch: ', epoch + 1, '/', epoch_num)\n",
    "        print('batches: ', batch_count)\n",
    "\n",
    "        for index in range(batch_count):\n",
    "            # select a batch data\n",
    "            image_blur_batch = x_train[index * batch_size:(index + 1) * batch_size]\n",
    "            image_full_batch = y_train[index * batch_size:(index + 1) * batch_size]\n",
    "            generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)\n",
    "\n",
    "            # output generated images for each 30 iters\n",
    "            if (index % 30 == 0) and (index != 0):\n",
    "                data_utils.generate_image(image_full_batch, image_blur_batch, generated_images,\n",
    "                                          'result/interim/', epoch, index)\n",
    "\n",
    "            # concatenate the full and generated images,\n",
    "            # the full images at top, the generated images at bottom\n",
    "            x = np.concatenate((image_full_batch, generated_images))\n",
    "\n",
    "            # train discriminator\n",
    "            d_loss = d.train_on_batch(x, d_labels)\n",
    "            print('batch %d d_loss : %f' % (index + 1, d_loss))\n",
    "\n",
    "            # keep the flag in step with how d_on_g was compiled (d frozen)\n",
    "            d.trainable = False\n",
    "\n",
    "            # train adversarial net\n",
    "            d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, g_labels)\n",
    "            print('batch %d d_on_g_loss : %f' % (index + 1, d_on_g_loss))\n",
    "\n",
    "            # train generator\n",
    "            g_loss = g.train_on_batch(image_blur_batch, image_full_batch)\n",
    "            print('batch %d g_loss : %f' % (index + 1, g_loss))\n",
    "\n",
    "            # restore the flag for the next discriminator step\n",
    "            d.trainable = True\n",
    "\n",
    "            # output weights for generator and discriminator each 30 iters\n",
    "            if (index % 30 == 0) and (index != 0):\n",
    "                _save_weights(g, d)\n",
    "\n",
    "        # also checkpoint at the end of every epoch, so short runs (<= 30 batches)\n",
    "        # and the final training steps are not lost; skip if nothing was trained\n",
    "        if batch_count > 0:\n",
    "            _save_weights(g, d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(batch_size):\n",
    "    '''Run the saved generator over the 'test' split and pass the results to data_utils.generate_image.\n",
    "\n",
    "    Args:\n",
    "        batch_size: batch size handed to the generator's predict().\n",
    "    '''\n",
    "    # load_data returns the full images first and the blurred images second\n",
    "    full_images, blur_images = data_utils.load_data(data_type='test')\n",
    "\n",
    "    # rebuild the generator and restore the weights written by train()\n",
    "    generator = generator_model()\n",
    "    generator.load_weights('weight/generator_weights.h5')\n",
    "\n",
    "    deblurred = generator.predict(x=blur_images, batch_size=batch_size)\n",
    "    data_utils.generate_image(full_images, blur_images, deblurred, 'result/finally/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_pictures(batch_size):\n",
    "    '''Run the saved generator over every file matching data/test/*.jpeg and save the outputs as PNG.\n",
    "\n",
    "    Args:\n",
    "        batch_size: batch size handed to the generator's predict().\n",
    "\n",
    "    Raises:\n",
    "        FileNotFoundError: if no file matches 'data/test/*.jpeg'.\n",
    "\n",
    "    Outputs are written to result/test/<i>.png, where i is the position of the\n",
    "    input in the sorted list of matched paths.\n",
    "    '''\n",
    "    data_path = 'data/test/*.jpeg'\n",
    "    # glob order is unspecified; sort so index i maps to the same input on every run\n",
    "    images_path = sorted(gb.glob(data_path))\n",
    "    if not images_path:\n",
    "        # fail early with a clear message instead of an obscure error further down\n",
    "        raise FileNotFoundError('no images match ' + data_path)\n",
    "\n",
    "    data_blur = []\n",
    "    for image_path in images_path:\n",
    "        # the context manager closes the file once the pixels are copied into an array\n",
    "        with Image.open(image_path) as image_blur:\n",
    "            data_blur.append(np.array(image_blur))\n",
    "\n",
    "    # NOTE(review): stacking assumes all images share one size and mode - confirm\n",
    "    data_blur = np.array(data_blur).astype(np.float32)\n",
    "    data_blur = data_utils.normalization(data_blur)\n",
    "\n",
    "    g = generator_model()\n",
    "    g.load_weights('weight/generator_weights.h5')\n",
    "    generated_images = g.predict(x=data_blur, batch_size=batch_size)\n",
    "    # scale back to pixel values; clip so anything outside 0..255 does not wrap in uint8\n",
    "    generated = np.clip(generated_images * 127.5 + 127.5, 0, 255)\n",
    "\n",
    "    # Image.save does not create missing directories\n",
    "    os.makedirs('result/test/', exist_ok=True)\n",
    "    for i in range(generated.shape[0]):\n",
    "        image_generated = generated[i, :, :, :]\n",
    "        Image.fromarray(image_generated.astype(np.uint8)).save('result/test/' + str(i) + '.png')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# train first so the weight files exist, then evaluate on the test split\n",
    "# and on the loose JPEGs in data/test/\n",
    "train(batch_size=2, epoch_num=10)\n",
    "test(batch_size=4)\n",
    "test_pictures(batch_size=2)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
